# Copyright (C) 2016-2018  Mikel Artetxe <artetxem@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import embeddings
from cupy_utils import *
from src.logger import create_logger

import argparse
import collections
import logging
import numpy as np
import re
import sys
import time


logger = create_logger()

def dropout(m, p):
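    """Return m with each entry zeroed out with probability p.

    Used for stochastic dictionary induction: each similarity score only
    survives with probability 1 - p, which adds noise to the argmax search.
    """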
    if p <= 0.0:
        return m
    else:
        xp = get_array_module(m)
        mask = xp.random.rand(*m.shape) >= p
        return m*mask


def topk_mean(m, k, inplace=False):  # TODO Assuming that axis is 1
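    """Return the mean of the k largest values in each row of m (axis 1).

    Runs k rounds of row-wise argmax, masking each selected entry with the
    global minimum, which avoids a full sort. m is modified in place when
    inplace=True; otherwise the function works on a copy.
    """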
    xp = get_array_module(m)
    n = m.shape[0]
    ans = xp.zeros(n, dtype=m.dtype)
    if k <= 0:
        return ans
    if not inplace:
        m = xp.array(m)
    ind0 = xp.arange(n)
    ind1 = xp.empty(n, dtype=int)
    minimum = m.min()
    for i in range(k):
        m.argmax(axis=1, out=ind1)
        ans += m[ind0, ind1]
        m[ind0, ind1] = minimum
    return ans / k

def parsing():
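    """Parse the command line arguments, applying the recommended presets.

    parse_args() is intentionally called twice: the first pass reads the
    chosen scenario (e.g. --unsupervised), set_defaults() installs the
    corresponding preset, and the second pass re-parses so that explicitly
    given flags still override the preset values.
    """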
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Map word embeddings in two languages into a shared space')
    parser.add_argument('src_input', help='the input source embeddings')
    parser.add_argument('trg_input', help='the input target embeddings')
    parser.add_argument('src_output', help='the output source embeddings')
    parser.add_argument('trg_output', help='the output target embeddings')
    parser.add_argument('--encoding', default='utf-8', help='the character encoding for input/output (defaults to utf-8)')
    parser.add_argument('--precision', choices=['fp16', 'fp32', 'fp64'], default='fp32', help='the floating-point precision (defaults to fp32)')
    parser.add_argument('--cuda', action='store_true', help='use CUDA for computation (requires cupy); defaults to CPU')
    parser.add_argument('--batch_size', default=10000, type=int, help='batch size (defaults to 10000); does not affect results, larger is usually faster but uses more memory')
    parser.add_argument('--seed', type=int, default=0, help='the random seed (defaults to 0)')

    recommended_group = parser.add_argument_group('recommended settings', 'Recommended settings for different scenarios')
    recommended_type = recommended_group.add_mutually_exclusive_group()
    recommended_type.add_argument('--supervised', metavar='DICTIONARY', help='recommended if you have a large training dictionary')
    recommended_type.add_argument('--semi_supervised', metavar='DICTIONARY', help='recommended if you have a small seed dictionary')
    recommended_type.add_argument('--identical', action='store_true', help='recommended if you have no seed dictionary but can rely on identical words')
    recommended_type.add_argument('--unsupervised', action='store_true', help='recommended if you have no seed dictionary and do not want to rely on identical words')
    recommended_type.add_argument('--acl2018', action='store_true', help='reproduce our ACL 2018 system')
    recommended_type.add_argument('--aaai2018', metavar='DICTIONARY', help='reproduce our AAAI 2018 system')
    recommended_type.add_argument('--acl2017', action='store_true', help='reproduce our ACL 2017 system with numeral initialization')
    recommended_type.add_argument('--acl2017_seed', metavar='DICTIONARY', help='reproduce our ACL 2017 system with a seed dictionary')
    recommended_type.add_argument('--emnlp2016', metavar='DICTIONARY', help='reproduce our EMNLP 2016 system')

    init_group = parser.add_argument_group('advanced initialization arguments', 'Advanced initialization arguments')
    init_type = init_group.add_mutually_exclusive_group()
    init_type.add_argument('-d', '--init_dictionary', default=sys.stdin.fileno(), metavar='DICTIONARY', help='the training dictionary file (defaults to stdin)')
    init_type.add_argument('--init_identical', action='store_true', help='use identical words as the seed dictionary')
    init_type.add_argument('--init_numerals', action='store_true', help='use latin numerals (i.e. words matching [0-9]+) as the seed dictionary')
    init_type.add_argument('--init_unsupervised', action='store_true', help='use unsupervised initialization')
    init_group.add_argument('--unsupervised_vocab', type=int, default=0, help='restrict the vocabulary to the top k entries for unsupervised initialization')

    mapping_group = parser.add_argument_group('advanced mapping arguments', 'Advanced embedding mapping arguments')
    mapping_group.add_argument('--normalize', choices=['unit', 'center', 'unitdim', 'centeremb', 'none'], nargs='*', default=[], help='the normalization actions to perform in order')
    mapping_group.add_argument('--whiten', action='store_true', help='whiten the embeddings')
    mapping_group.add_argument('--src_reweight', type=float, default=0, nargs='?', const=1, help='re-weight the source language embeddings')
    mapping_group.add_argument('--trg_reweight', type=float, default=0, nargs='?', const=1, help='re-weight the target language embeddings')
    mapping_group.add_argument('--src_dewhiten', choices=['src', 'trg'], help='de-whiten the source language embeddings')
    mapping_group.add_argument('--trg_dewhiten', choices=['src', 'trg'], help='de-whiten the target language embeddings')
    mapping_group.add_argument('--dim_reduction', type=int, default=0, help='apply dimensionality reduction')
    mapping_type = mapping_group.add_mutually_exclusive_group()
    mapping_type.add_argument('-c', '--orthogonal', action='store_true', help='use orthogonal constrained mapping')
    mapping_type.add_argument('-u', '--unconstrained', action='store_true', help='use unconstrained mapping')

    self_learning_group = parser.add_argument_group('advanced self-learning arguments', 'Advanced arguments for self-learning')
    self_learning_group.add_argument('--self_learning', action='store_true', help='enable self-learning')
    self_learning_group.add_argument('--vocabulary_cutoff', type=int, default=0, help='restrict the vocabulary to the top k entries')
    self_learning_group.add_argument('--direction', choices=['forward', 'backward', 'union'], default='union', help='the direction for dictionary induction (defaults to union)')
    self_learning_group.add_argument('--csls', type=int, nargs='?', default=0, const=10, metavar='NEIGHBORHOOD_SIZE', dest='csls_neighborhood', help='use CSLS for dictionary induction')
    self_learning_group.add_argument('--threshold', default=0.000001, type=float, help='the convergence threshold (defaults to 0.000001)')
    self_learning_group.add_argument('--validation', default=None, metavar='DICTIONARY', help='a dictionary file for validation at each iteration')
    self_learning_group.add_argument('--stochastic_initial', default=0.1, type=float, help='initial keep probability for stochastic dictionary induction (defaults to 0.1)')
    self_learning_group.add_argument('--stochastic_multiplier', default=2.0, type=float, help='stochastic dictionary induction multiplier (defaults to 2.0)')
    self_learning_group.add_argument('--stochastic_interval', default=50, type=int, help='stochastic dictionary induction interval (defaults to 50)')
    self_learning_group.add_argument('--log', help='write to a log file in tsv format at each iteration')
    self_learning_group.add_argument('-v', '--verbose', action='store_true', help='write log information to stderr at each iteration')
    args = parser.parse_args()

    if args.supervised is not None:
        parser.set_defaults(init_dictionary=args.supervised, normalize=['unit', 'center', 'unit'], whiten=True, src_reweight=0.5, trg_reweight=0.5, src_dewhiten='src', trg_dewhiten='trg', batch_size=1000)
    if args.semi_supervised is not None:
        parser.set_defaults(init_dictionary=args.semi_supervised, normalize=['unit', 'center', 'unit'], whiten=True, src_reweight=0.5, trg_reweight=0.5, src_dewhiten='src', trg_dewhiten='trg', self_learning=True, vocabulary_cutoff=20000, csls_neighborhood=10)
    if args.identical:
        parser.set_defaults(init_identical=True, normalize=['unit', 'center', 'unit'], whiten=True, src_reweight=0.5, trg_reweight=0.5, src_dewhiten='src', trg_dewhiten='trg', self_learning=True, vocabulary_cutoff=20000, csls_neighborhood=10)
    if args.unsupervised or args.acl2018:
        parser.set_defaults(init_unsupervised=True, unsupervised_vocab=4000, normalize=['unit', 'center', 'unit'], whiten=True, src_reweight=0.5, trg_reweight=0.5, src_dewhiten='src', trg_dewhiten='trg', self_learning=True, vocabulary_cutoff=20000, csls_neighborhood=10)
    if args.aaai2018:
        parser.set_defaults(init_dictionary=args.aaai2018, normalize=['unit', 'center'], whiten=True, trg_reweight=1, src_dewhiten='src', trg_dewhiten='trg', batch_size=1000)
    if args.acl2017:
        parser.set_defaults(init_numerals=True, orthogonal=True, normalize=['unit', 'center'], self_learning=True, direction='forward', stochastic_initial=1.0, stochastic_interval=1, batch_size=1000)
    if args.acl2017_seed:
        parser.set_defaults(init_dictionary=args.acl2017_seed, orthogonal=True, normalize=['unit', 'center'], self_learning=True, direction='forward', stochastic_initial=1.0, stochastic_interval=1, batch_size=1000)
    if args.emnlp2016:
        parser.set_defaults(init_dictionary=args.emnlp2016, orthogonal=True, normalize=['unit', 'center'], batch_size=1000)
    args = parser.parse_args()

    return args

def build_seed_dictionary(args, ctx):
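    """Build the initial training dictionary as parallel src/trg index lists.

    Four strategies, in order of precedence:
    - init_unsupervised: match words whose sorted intra-lingual similarity
      distributions are most alike (optionally rescored with CSLS);
    - init_numerals: pair words matching [0-9]+ in both vocabularies;
    - init_identical: pair words with identical spelling;
    - otherwise: read an explicit dictionary file, warning on OOV entries.
    """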
    # Build the seed dictionary
    src_indices = []
    trg_indices = []
    xp = get_array_module(ctx.x)
    if args.init_unsupervised:
        sim_size = min(ctx.x.shape[0], ctx.z.shape[0]) if args.unsupervised_vocab <= 0\
            else min(ctx.x.shape[0], ctx.z.shape[0], args.unsupervised_vocab)
        u, s, vt = xp.linalg.svd(ctx.x[:sim_size], full_matrices=False)
        xsim = (u*s).dot(u.T)
        u, s, vt = xp.linalg.svd(ctx.z[:sim_size], full_matrices=False)
        zsim = (u*s).dot(u.T)
        del u, s, vt
        xsim.sort(axis=1)
        zsim.sort(axis=1)
        embeddings.normalize(xsim, args.normalize)
        embeddings.normalize(zsim, args.normalize)
        sim = xsim.dot(zsim.T)
        if args.csls_neighborhood > 0:
            knn_sim_fwd = topk_mean(sim, k=args.csls_neighborhood)
            knn_sim_bwd = topk_mean(sim.T, k=args.csls_neighborhood)
            sim -= knn_sim_fwd[:, xp.newaxis]/2 + knn_sim_bwd/2
        if args.direction == 'forward':
            src_indices = xp.arange(sim_size)
            trg_indices = sim.argmax(axis=1)
        elif args.direction == 'backward':
            src_indices = sim.argmax(axis=0)
            trg_indices = xp.arange(sim_size)
        elif args.direction == 'union':
            src_indices = xp.concatenate((xp.arange(sim_size), sim.argmax(axis=0)))
            trg_indices = xp.concatenate((sim.argmax(axis=1), xp.arange(sim_size)))
        del xsim, zsim, sim
    elif args.init_numerals:
        numeral_regex = re.compile('^[0-9]+$')
        src_numerals = {word for word in ctx.src_words if numeral_regex.match(word) is not None}
        trg_numerals = {word for word in ctx.trg_words if numeral_regex.match(word) is not None}
        numerals = src_numerals.intersection(trg_numerals)
        for word in numerals:
            src_indices.append(ctx.src_word2ind[word])
            trg_indices.append(ctx.trg_word2ind[word])
    elif args.init_identical:
        identical = set(ctx.src_words).intersection(set(ctx.trg_words))
        for word in identical:
            src_indices.append(ctx.src_word2ind[word])
            trg_indices.append(ctx.trg_word2ind[word])
    else:
        f = open(args.init_dictionary, encoding=args.encoding, errors='surrogateescape')
        for line in f:
            src, trg = line.split()
            try:
                src_ind = ctx.src_word2ind[src]
                trg_ind = ctx.trg_word2ind[trg]
                src_indices.append(src_ind)
                trg_indices.append(trg_ind)
            except KeyError:
                logger.warning('OOV dictionary entry ({0} - {1})'.format(src, trg))
    return src_indices, trg_indices

def dataload(args, ctx):
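    """Load the embeddings and populate the context object ctx.

    Validates the arguments, selects the dtype for the requested precision,
    reads the source/target embeddings (vocabulary capped at 200000 entries),
    moves them to the GPU when --cuda is set, normalizes them, builds the
    seed dictionary and, if requested, loads the validation dictionary and
    computes its coverage. Returns the seed (src_indices, trg_indices).
    """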
    # Check command line arguments
    if (args.src_dewhiten is not None or args.trg_dewhiten is not None) and not args.whiten:
        logger.error('De-whitening requires whitening first')
        sys.exit(-1)

    # Choose the right dtype for the desired precision
    if args.precision == 'fp16':
        ctx.dtype = 'float16'
    elif args.precision == 'fp32':
        ctx.dtype = 'float32'
    elif args.precision == 'fp64':
        ctx.dtype = 'float64'

    # Read input embeddings
    srcfile = open(args.src_input, encoding=args.encoding, errors='surrogateescape')
    trgfile = open(args.trg_input, encoding=args.encoding, errors='surrogateescape')
    ctx.src_words, ctx.x = embeddings.read(srcfile, threshold=200000, dtype=ctx.dtype)
    ctx.trg_words, ctx.z = embeddings.read(trgfile, threshold=200000, dtype=ctx.dtype)

    # NumPy/CuPy management
    if args.cuda:
        if not supports_cupy():
            logger.error('Install CuPy for CUDA support')
            sys.exit(-1)
        ctx.xp = get_cupy()
        ctx.x = ctx.xp.asarray(ctx.x)
        ctx.z = ctx.xp.asarray(ctx.z)
    else:
        ctx.xp = np

    ctx.xp.random.seed(args.seed)

    # Build word to index map
    ctx.src_word2ind = {word: i for i, word in enumerate(ctx.src_words)}
    ctx.trg_word2ind = {word: i for i, word in enumerate(ctx.trg_words)}

    # STEP 0: Normalization
    embeddings.normalize(ctx.x, args.normalize)
    embeddings.normalize(ctx.z, args.normalize)
    src_indices, trg_indices = build_seed_dictionary(args, ctx)

    # Read validation dictionary
    if args.validation is not None:
        f = open(args.validation, encoding=args.encoding, errors='surrogateescape')
        ctx.validation = collections.defaultdict(set)
        oov = set()
        vocab = set()
        for line in f:
            words = line.split()
            if len(words) != 2:
                continue
            src, trg = words
            try:
                src_ind = ctx.src_word2ind[src]
                trg_ind = ctx.trg_word2ind[trg]
                ctx.validation[src_ind].add(trg_ind)
                vocab.add(src)
            except KeyError:
                oov.add(src)
        oov -= vocab  # If one of the translation options is in the vocabulary, then the entry is not an oov
        ctx.validation_coverage = len(ctx.validation) / (len(ctx.validation) + len(oov))
    
    return src_indices, trg_indices

def update_embedding_mapping(args, ctx, src_indices, xw, trg_indices, zw):
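    """Compute the current mapping and write the mapped embeddings to xw/zw.

    Three variants:
    - orthogonal (also forced while self-learning, i.e. while not ctx.end):
      Procrustes solution W = V U^T from the SVD of Z^T X over the dictionary;
    - unconstrained: least-squares solution W = (X^T X)^{-1} X^T Z;
    - advanced: whitening, orthogonal mapping, re-weighting by the singular
      values, de-whitening, and optional dimensionality reduction.

    Returns (xw, zw, mapping), where mapping is the explicit matrix W for
    the first two variants and None for the advanced pipeline.
    """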
    # Update the embedding mapping
    xp = get_array_module(ctx.x)
    mapping = None
    if args.orthogonal or not ctx.end:  # orthogonal mapping
        u, s, vt = xp.linalg.svd(ctx.z[trg_indices].T.dot(ctx.x[src_indices]))
        w = vt.T.dot(u.T)
        mapping = w
        ctx.x.dot(w, out=xw)
        zw[:] = ctx.z
    elif args.unconstrained:  # unconstrained mapping
        x_pseudoinv = xp.linalg.inv(ctx.x[src_indices].T.dot(ctx.x[src_indices])).dot(ctx.x[src_indices].T)
        w = x_pseudoinv.dot(ctx.z[trg_indices])
        mapping = w
        ctx.x.dot(w, out=xw)
        zw[:] = ctx.z
    else:  # advanced mapping

        # TODO xw.dot(wx2, out=xw) and alike not working
        xw[:] = ctx.x
        zw[:] = ctx.z

        # STEP 1: Whitening
        def whitening_transformation(m):
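            # With m = U S Vt (thin SVD), this is W = V diag(1/s) Vt, so that
            # (m W)^T (m W) = I: the mapped dictionary rows are whitened.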
            u, s, vt = xp.linalg.svd(m, full_matrices=False)
            return vt.T.dot(xp.diag(1/s)).dot(vt)
        if args.whiten:
            wx1 = whitening_transformation(xw[src_indices])
            wz1 = whitening_transformation(zw[trg_indices])
            xw = xw.dot(wx1)
            zw = zw.dot(wz1)

        # STEP 2: Orthogonal mapping
        wx2, s, wz2_t = xp.linalg.svd(xw[src_indices].T.dot(zw[trg_indices]))
        wz2 = wz2_t.T
        xw = xw.dot(wx2)
        zw = zw.dot(wz2)

        # STEP 3: Re-weighting
        xw *= s**args.src_reweight
        zw *= s**args.trg_reweight

        # STEP 4: De-whitening
        if args.src_dewhiten == 'src':
            xw = xw.dot(wx2.T.dot(xp.linalg.inv(wx1)).dot(wx2))
        elif args.src_dewhiten == 'trg':
            xw = xw.dot(wz2.T.dot(xp.linalg.inv(wz1)).dot(wz2))
        if args.trg_dewhiten == 'src':
            zw = zw.dot(wx2.T.dot(xp.linalg.inv(wx1)).dot(wx2))
        elif args.trg_dewhiten == 'trg':
            zw = zw.dot(wz2.T.dot(xp.linalg.inv(wz1)).dot(wz2))

        # STEP 5: Dimensionality reduction
        if args.dim_reduction > 0:
            xw = xw[:, :args.dim_reduction]
            zw = zw[:, :args.dim_reduction]
        
    
    return xw, zw, mapping

def build_dictionary(args, ctx, xw, zw, keep_prob=1.0):
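    """Induce a new training dictionary from the mapped embeddings xw/zw.

    Nearest neighbors are computed in batches of args.batch_size. With
    --csls, each candidate's score is penalized by half the mean similarity
    of its k nearest neighbors in the opposite direction (the row-constant
    half of the full CSLS score is omitted, as it does not affect the
    argmax). keep_prob < 1 applies dropout to the scores before the argmax
    (stochastic dictionary induction).

    Returns (objective, src_indices, trg_indices), where objective is the
    mean best similarity over the chosen direction(s).
    """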
    xp = get_array_module(xw)
    src_size, trg_size, dtype = ctx.src_size, ctx.trg_size, ctx.dtype
    best_sim_forward = xp.full(src_size, -100, dtype=dtype)
    src_indices_forward = xp.arange(src_size)
    trg_indices_forward = xp.zeros(src_size, dtype=int)
    best_sim_backward = xp.full(trg_size, -100, dtype=dtype)
    src_indices_backward = xp.zeros(trg_size, dtype=int)
    trg_indices_backward = xp.arange(trg_size)
    knn_sim_fwd = xp.zeros(src_size, dtype=dtype)
    knn_sim_bwd = xp.zeros(trg_size, dtype=dtype)

    simfwd = xp.empty((args.batch_size, trg_size), dtype=dtype)
    simbwd = xp.empty((args.batch_size, src_size), dtype=dtype)

    # Update the training dictionary
    if args.direction in ('forward', 'union'):
        if args.csls_neighborhood > 0:
            for i in range(0, trg_size, simbwd.shape[0]):
                j = min(i + simbwd.shape[0], trg_size)
                zw[i:j].dot(xw[:src_size].T, out=simbwd[:j-i])
                knn_sim_bwd[i:j] = topk_mean(simbwd[:j-i], k=args.csls_neighborhood, inplace=True)
        for i in range(0, src_size, simfwd.shape[0]):
            j = min(i + simfwd.shape[0], src_size)
            xw[i:j].dot(zw[:trg_size].T, out=simfwd[:j-i])
            simfwd[:j-i].max(axis=1, out=best_sim_forward[i:j])
            simfwd[:j-i] -= knn_sim_bwd/2  # Equivalent to the real CSLS scores for NN
            dropout(simfwd[:j-i], 1 - keep_prob).argmax(axis=1, out=trg_indices_forward[i:j])
    if args.direction in ('backward', 'union'):
        if args.csls_neighborhood > 0:
            for i in range(0, src_size, simfwd.shape[0]):
                j = min(i + simfwd.shape[0], src_size)
                xw[i:j].dot(zw[:trg_size].T, out=simfwd[:j-i])
                knn_sim_fwd[i:j] = topk_mean(simfwd[:j-i], k=args.csls_neighborhood, inplace=True)
        for i in range(0, trg_size, simbwd.shape[0]):
            j = min(i + simbwd.shape[0], trg_size)
            zw[i:j].dot(xw[:src_size].T, out=simbwd[:j-i])
            simbwd[:j-i].max(axis=1, out=best_sim_backward[i:j])
            simbwd[:j-i] -= knn_sim_fwd/2  # Equivalent to the real CSLS scores for NN
            dropout(simbwd[:j-i], 1 - keep_prob).argmax(axis=1, out=src_indices_backward[i:j])
    if args.direction == 'forward':
        src_indices = src_indices_forward
        trg_indices = trg_indices_forward
        objective = xp.mean(best_sim_forward).tolist()
    elif args.direction == 'backward':
        src_indices = src_indices_backward
        trg_indices = trg_indices_backward
        objective = xp.mean(best_sim_backward).tolist()
    elif args.direction == 'union':
        src_indices = xp.concatenate((src_indices_forward, src_indices_backward))
        trg_indices = xp.concatenate((trg_indices_forward, trg_indices_backward))
        objective = (xp.mean(best_sim_forward) + xp.mean(best_sim_backward)).tolist() / 2

    return objective, src_indices, trg_indices

def evaluating_csls_forward(args, xw, zw, src_indices, trg_indices, simval, validation, it, t, objective, keep_prob, validation_coverage):
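    """Evaluate the induced dictionary on the validation set and log progress.

    Accuracy checks the induced translation trg_indices[i] against the gold
    validation entries; similarity is the mean best gold-pair similarity
    under the current mapping. Logs only when --verbose is set.
    """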
    if args.validation is not None:
        src = list(validation.keys())
        xw[src].dot(zw.T, out=simval)
        accuracy = np.mean([1 if int(trg_indices[i]) in validation[i] else 0 for i in src])
        similarity = np.mean([max([simval[i, j].tolist() for j in validation[src[i]]]) for i in range(len(src))])

    # Logging
    duration = time.time() - t
    if args.verbose:
        eval_info = ""
        eval_info += 'ITERATION {0} ({1:.2f}s)'.format(it, duration)
        eval_info += '\t- Objective:        {0:9.4f}%'.format(100 * objective)
        eval_info += '\t- Drop probability: {0:9.4f}%'.format(100 - 100*keep_prob)
        if args.validation is not None:
            eval_info += '\t- Val. similarity:  {0:9.4f}%'.format(100 * similarity)
            eval_info += '\t- Val. accuracy:    {0:9.4f}%'.format(100 * accuracy)
            eval_info += '\t- Val. coverage:    {0:9.4f}%'.format(100 * validation_coverage)
        logger.info(eval_info)
    # if args.log is not None:
    #     val = '{0:.6f}\t{1:.6f}\t{2:.6f}'.format(
    #         100 * similarity, 100 * accuracy, 100 * validation_coverage) if args.validation is not None else ''
    #     print('{0}\t{1:.6f}\t{2}\t{3:.6f}'.format(it, 100 * objective, val, duration), file=log)
    #     log.flush()

def evaluating_nn(args,ctx, xw, zw, it, t, objective, keep_prob):
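    """Evaluate nearest-neighbor retrieval on the validation set and log it.

    For every validation source word, retrieves its nearest target word
    under the current mapping, scoring accuracy against the gold
    translations and the mean best gold-pair similarity. Logs only when
    --verbose is set.
    """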
    if args.validation is not None:
        src = list(ctx.validation.keys())
        xw[src].dot(zw.T, out=ctx.simval)
        nn = asnumpy(ctx.simval.argmax(axis=1))
        accuracy = np.mean([1 if nn[i] in ctx.validation[src[i]] else 0 for i in range(len(src))])
        similarity = np.mean([max([ctx.simval[i, j].tolist() for j in ctx.validation[src[i]]]) for i in range(len(src))])

    # Logging
    duration = time.time() - t
    if args.verbose:
        eval_info = ""
        eval_info += 'ITERATION {0} ({1:.2f}s)'.format(it, duration)
        eval_info += '\t- Objective:        {0:9.4f}%'.format(100 * objective)
        eval_info += '\t- Drop probability: {0:9.4f}%'.format(100 - 100*keep_prob)
        if args.validation is not None:
            eval_info += '\t- Val. similarity:  {0:9.4f}%'.format(100 * similarity)
            eval_info += '\t- Val. accuracy:    {0:9.4f}%'.format(100 * accuracy)
            eval_info += '\t- Val. coverage:    {0:9.4f}%'.format(100 * ctx.validation_coverage)
        logger.info(eval_info)
    # if args.log is not None:
    #     val = '{0:.6f}\t{1:.6f}\t{2:.6f}'.format(
    #         100 * similarity, 100 * accuracy, 100 * validation_coverage) if args.validation is not None else ''
    #     print('{0}\t{1:.6f}\t{2}\t{3:.6f}'.format(it, 100 * objective, val, duration), file=log)
    #     log.flush()

def main(args=None):
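    """Run the full pipeline: load data, self-learn, write mapped embeddings.

    Returns the last explicit mapping matrix that was computed (None if only
    the advanced mapping ever ran).

    Example invocation (hypothetical script and file names):

        python map_embeddings.py --unsupervised SRC.EMB TRG.EMB SRC_MAPPED.EMB TRG_MAPPED.EMB
    """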
    ctx = argparse.Namespace()
    if not args:
        args = parsing()

    # Create log file
    if args.log:
        global logger
        logger = create_logger(args.log)

    logger.info('============ The training args============')
    logger.info('\n'.join('%s: %s' % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))

    src_indices, trg_indices = dataload(args, ctx)
    
    # Allocate memory
    xw = ctx.xp.empty_like(ctx.x)
    zw = ctx.xp.empty_like(ctx.z)
    ctx.src_size = ctx.x.shape[0] if args.vocabulary_cutoff <= 0 else min(ctx.x.shape[0], args.vocabulary_cutoff)
    ctx.trg_size = ctx.z.shape[0] if args.vocabulary_cutoff <= 0 else min(ctx.z.shape[0], args.vocabulary_cutoff)
    if args.validation is not None:
        ctx.simval = ctx.xp.empty((len(ctx.validation.keys()), ctx.z.shape[0]), dtype=ctx.dtype)

    # Evaluating the pretrained vectors
    # (can also be seen as evaluating the seed dictionary)
    # objective, src_indices, trg_indices =\
    #     build_dictionary(args, x[:src_size], z[:trg_size], dtype)
    # evaluating_nn(args, x[:src_size], z[:trg_size], simval, validation, -1, time.time(), objective, 0,\
    #     validation_coverage, log)

    # Training loop
    best_objective = objective = -100.
    it = 1
    last_improvement = 0
    keep_prob = args.stochastic_initial
    ctx.end = not args.self_learning
    final_mapping = None
    while True:

        # Increase the keep probability if we have not improved in the last args.stochastic_interval iterations
        t = time.time()
        if it - last_improvement > args.stochastic_interval:
            if keep_prob >= 1.0:
                ctx.end = True
            keep_prob = min(1.0, args.stochastic_multiplier*keep_prob)
            last_improvement = it

        xw, zw, _mapping = update_embedding_mapping(args, ctx, \
            src_indices, xw, trg_indices, zw)

        if _mapping is not None:
            final_mapping = _mapping
        # Self-learning
        if ctx.end:
            break

        objective, src_indices, trg_indices =\
            build_dictionary(args, ctx, xw, zw, keep_prob=keep_prob)
        if objective - best_objective >= args.threshold:
            last_improvement = it
            best_objective = objective
        # Accuracy and similarity evaluation in validation
        evaluating_nn(args, ctx, xw, zw, it, t, objective, keep_prob)

        it += 1

    # Write mapped embeddings
    srcfile = open(args.src_output, mode='w', encoding=args.encoding, errors='surrogateescape')
    trgfile = open(args.trg_output, mode='w', encoding=args.encoding, errors='surrogateescape')
    embeddings.write(ctx.src_words, xw, srcfile)
    embeddings.write(ctx.trg_words, zw, trgfile)
    srcfile.close()
    trgfile.close()
    
    return final_mapping


if __name__ == '__main__':
    main()